{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Minifloat and Groupwise quantization"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook shows some practical use cases for minifloat and groupwise quantization.\n",
    "\n",
    "Brevitas supports a wide range of float quantization formats, including the OCP and FNUZ FP8 standards.\n",
    "It is possible to define any combination of exponent/mantissa bitwidth, as well as the exponent bias.\n",
    "\n",
    "Similarly, MX quantization is supported through general groupwise quantization on top of integer or minifloat data types.\n",
    "This enables any general groupwise quantization scheme, including the MXInt and MXFloat standards.\n",
    "\n",
    "This tutorial shows how to instantiate and use some of the most interesting quantizers for minifloat and groupwise quantization."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Minifloat (FP8 and lower)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Brevitas offers some pre-defined quantizers for minifloat quantization, including the OCP and FNUZ standards, which can be further customized according to the specific use case.\n",
    "The general naming structure for the quantizers is the following:\n",
    "\n",
    "`Fp\\<Bitwidth\\>\\<Standard\\>Weight\\<Scaling\\>Float`\n",
    "\n",
    "Where `Bitwidth` can be either empty or `e4m3`/`e5m2`, `Standard` can be empty or `OCP`/`FNUZ`, `Scaling` can be empty or `PerTensor`/`PerChannel`.\n",
    "\n",
    "If `Bitwidth` is empty, the user must set it with kwargs or by subclassing the quantizers. Once the bitwidth is defined, the correct values for inf/nan are automatically defined based on the `Standard`.\n",
    "If a non-valid OCP bitwidth is set (e.g., e6m1), then no inf/nan values will be selected and the corresponding quantizer is not standard-compliant.\n",
    "\n",
    "`Standard` selects between the two main FP8 standards; if it is not specified, Brevitas can perform minifloat quantization without reserving any values for the inf/nan representation.\n",
    "This makes the maximum range available, since in quantization values that exceed the quantization range often saturate to the maximum rather than going to inf/nan.\n",
    "FNUZ quantizers need to have `saturating=True`.\n",
    "\n",
    "The `Scaling` option defines whether the quantization is _scaled_ or _unscaled_.\n",
    "In the unscaled case, the scale factor for quantization is fixed to one; otherwise it can be set using any of the methods that Brevitas includes (e.g., statistics, learned, etc.).\n",
    "\n",
    "\n",
    "Please keep in mind that not all combinations of the above options are necessarily pre-defined; the naming scheme mostly serves as an indication of what Brevitas supports.\n",
    "Following the same structure as the available quantizers, it is possible to define new ones that fit your needs.\n",
    "\n",
    "\n",
    "Similar considerations apply to activation quantization."
   ]
  },
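  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough illustration of the range argument above, the following pure-Python sketch computes the largest finite value of a minifloat format with and without encodings reserved for inf/nan, and shows saturation to that maximum.\n",
    "The helpers `max_minifloat` and `saturate` are hypothetical and written only for this example; they are not part of the Brevitas API. The exponent biases (7 for e4m3, 15 for e5m2) are assumed here.\n",
    "\n",
    "```python\n",
    "def max_minifloat(exponent_bits, mantissa_bits, bias, reserved='none'):\n",
    "    # Largest finite value of a minifloat format (hypothetical helper, not a Brevitas API).\n",
    "    # reserved='none': every encoding is a finite number\n",
    "    # reserved='top_mantissa': only the all-ones mantissa at the top exponent is lost (OCP e4m3 style)\n",
    "    # reserved='top_exponent': the whole top exponent is lost to inf/nan (OCP e5m2 style)\n",
    "    max_exp = 2 ** exponent_bits - 1\n",
    "    max_man = 2 ** mantissa_bits - 1\n",
    "    if reserved == 'top_exponent':\n",
    "        max_exp -= 1\n",
    "    elif reserved == 'top_mantissa':\n",
    "        max_man -= 1\n",
    "    return (1 + max_man / 2 ** mantissa_bits) * 2.0 ** (max_exp - bias)\n",
    "\n",
    "\n",
    "def saturate(x, max_value):\n",
    "    # Clamp out-of-range values to +/- max_value instead of producing inf/nan\n",
    "    return max(-max_value, min(x, max_value))\n",
    "\n",
    "\n",
    "e4m3_full = max_minifloat(4, 3, bias=7)                           # 480.0\n",
    "e4m3_ocp = max_minifloat(4, 3, bias=7, reserved='top_mantissa')   # 448.0\n",
    "e5m2_ocp = max_minifloat(5, 2, bias=15, reserved='top_exponent')  # 57344.0\n",
    "print(e4m3_full, e4m3_ocp, e5m2_ocp)\n",
    "print(saturate(1000.0, e4m3_ocp))  # 448.0\n",
    "```\n",
    "\n",
    "Under these assumptions, reserving an encoding for nan lowers the e4m3 maximum from 480 to 448; this is the extra range that a minifloat format without reserved values keeps."
   ]
  },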
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-05-09T14:50:24.350445Z",
     "iopub.status.busy": "2025-05-09T14:50:24.349767Z",
     "iopub.status.idle": "2025-05-09T14:50:31.322338Z",
     "shell.execute_reply": "2025-05-09T14:50:31.319353Z"
    }
   },
   "outputs": [
   ],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "import brevitas.nn as qnn\n",
    "from brevitas.quant.experimental.float_base import Fp8e4m3Mixin\n",
    "from brevitas.quant.experimental.float_quant_ocp import FpOCPActPerTensorFloat\n",
    "from brevitas.quant.experimental.float_quant_ocp import FpOCPWeightPerTensorFloat\n",
    "from brevitas.quant.experimental.mx_quant_ocp import MXFloat8e4m3Act\n",
    "from brevitas.quant.experimental.mx_quant_ocp import MXFloat8e4m3Weight\n",
    "from brevitas.quant_tensor import FloatQuantTensor\n",
    "\n",
    "\n",
    "# The base quantizer leaves the bitwidth empty; Fp8e4m3Mixin sets it to e4m3\n",
    "class OCPFP8Weight(FpOCPWeightPerTensorFloat, Fp8e4m3Mixin):\n",
    "    pass\n",
    "\n",
    "\n",
    "class OCPFP8Act(FpOCPActPerTensorFloat, Fp8e4m3Mixin):\n",
    "    pass\n",
    "\n",
    "\n",
    "class FP8Model(nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.conv = qnn.QuantConv2d(32, 64, 3, weight_quant=OCPFP8Weight, input_quant=OCPFP8Act)\n",
    "    \n",
    "    def forward(self, x):\n",
    "        return self.conv(x)\n",
    "\n",
    "ocp_fp8_model = FP8Model()\n",
    "x = torch.randn(1, 32, 8, 8)\n",
    "ocp_fp8_model.eval()\n",
    "o = ocp_fp8_model(x)\n",
    "\n",
    "# Both the quantized input and the quantized weight are returned as FloatQuantTensor\n",
    "intermediate_input = ocp_fp8_model.conv.input_quant(x)\n",
    "assert isinstance(intermediate_input, FloatQuantTensor)\n",
    "assert isinstance(ocp_fp8_model.conv.quant_weight(), FloatQuantTensor)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Groupwise quantization (MXInt/MXFloat)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Groupwise quantization is built on top of integer/minifloat quantization, with special considerations to accommodate the groupwise scaling.\n",
    "\n",
    "Compared to Int/Float QuantTensor, the main difference in their groupwise equivalents is that `value`, `scale`, and `zero_point` are no longer direct attributes but properties. The new attributes are `value_`, `scale_`, and `zero_point_`.\n",
    "\n",
    "The reason for this is shaping. When quantizing a tensor with shape [O, I], where O is the number of output channels and I is the number of input channels, with group size k, groupwise quantization is normally represented as follows:\n",
    "\n",
    "- Tensor with shape [O, I/k, k]\n",
    "- Scales with shape [O, I/k, 1]\n",
    "- Zero point with the same shape as the scales\n",
    "\n",
    "The alternative to this representation is to have all three tensors with shape [O, I], with a massive increase in memory utilization, especially with QAT + gradients.\n",
    "\n",
    "The underscored attributes have the compressed shapes, while the properties (non-underscored naming) dynamically compute the expanded version. This means:\n",
    "```python\n",
    "quant_tensor.scale_.shape\n",
    "# Compressed shape: [O, I/k, 1]\n",
    "quant_tensor.scale.shape\n",
    "# Expanded shape: [O, I]\n",
    "```\n",
    "\n",
    "With respect to pre-defined quantizers, Brevitas offers several Groupwise and MX options.\n",
    "The main difference between the two is that MX is restricted to group_size=32 and the scale factor must be a power-of-2.\n",
    "The user can override these settings but the corresponding output won't be MX compliant.\n",
    "\n",
    "Another difference is that MXFloat relies on the OCP format as underlying data type, while generic groupwise float relies on the non-standard minifloat representation explained above.\n",
    "\n",
    "Finally, the general groupwise scaling relies on float scales."
   ]
  },
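  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The difference between the compressed and the expanded scale can be sketched in pure Python on a single row of values.\n",
    "This is only an illustration: `group_scales` and `expand` are hypothetical helpers, not Brevitas functions, and rounding the per-group maximum magnitude up to a power of two is an assumed rule chosen to mimic power-of-2 scales, not necessarily the rule used by Brevitas or required by the MX standard.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "\n",
    "def group_scales(row, group_size):\n",
    "    # One power-of-two scale per group (hypothetical helper, not a Brevitas API).\n",
    "    # Assumes len(row) is divisible by group_size.\n",
    "    groups = [row[i:i + group_size] for i in range(0, len(row), group_size)]\n",
    "    scales = []\n",
    "    for group in groups:\n",
    "        absmax = max(abs(v) for v in group)\n",
    "        # Round the absmax up to a power of two; an all-zero group falls back to 1.0\n",
    "        scales.append(2.0 ** math.ceil(math.log2(absmax)) if absmax > 0 else 1.0)\n",
    "    return scales\n",
    "\n",
    "\n",
    "def expand(scales, group_size):\n",
    "    # Repeat each group scale so that there is one scale per element\n",
    "    return [s for s in scales for _ in range(group_size)]\n",
    "\n",
    "\n",
    "row = [0.5, -3.0, 1.0, 0.25, 6.0, 0.1, -0.2, 0.7]  # one output channel, 8 inputs\n",
    "compressed = group_scales(row, group_size=4)       # 2 groups -> 2 scales\n",
    "expanded = expand(compressed, group_size=4)        # 8 elements -> 8 scales\n",
    "print(compressed)  # [4.0, 8.0]\n",
    "print(expanded)    # [4.0, 4.0, 4.0, 4.0, 8.0, 8.0, 8.0, 8.0]\n",
    "```\n",
    "\n",
    "The compressed list plays the role of `scale_`, while the expanded list plays the role of `scale`: it carries no extra information, which is why computing it on demand saves memory."
   ]
  },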
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-05-09T14:50:31.546766Z",
     "iopub.status.busy": "2025-05-09T14:50:31.545457Z",
     "iopub.status.idle": "2025-05-09T14:50:31.622935Z",
     "shell.execute_reply": "2025-05-09T14:50:31.619864Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/proj/xlabs/users/nfraser/opt/miniforge3/envs/20231115_brv_pt1.13.1/lib/python3.10/site-packages/brevitas/quant/solver/act.py:134: UserWarning: Group dim is being selected assuming batched input. Using unbatched input will fail and requires manually specification of group_dim\n",
      "  warn(\n"
     ]
    }
   ],
   "source": [
    "from brevitas.quant_tensor import GroupwiseFloatQuantTensor\n",
    "\n",
    "\n",
    "class MXFloat8Weight(MXFloat8e4m3Weight, Fp8e4m3Mixin):\n",
    "    # The group dimension for the weights is automatically identified based on the layer type\n",
    "    # If a new layer type is used, it can be manually specified\n",
    "    pass\n",
    "\n",
    "class MXFloat8Act(MXFloat8e4m3Act, Fp8e4m3Mixin):\n",
    "    # In layerwise quantization, groupdim is automatically determined\n",
    "    pass\n",
    "\n",
    "\n",
    "class MXModel(nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.conv = qnn.QuantConv2d(32, 64, 3, weight_quant=MXFloat8Weight, input_quant=MXFloat8Act)\n",
    "    \n",
    "    def forward(self, x):\n",
    "        return self.conv(x)\n",
    "\n",
    "mx_model = MXModel()\n",
    "x = torch.randn(1, 32, 8, 8)\n",
    "mx_model.eval()\n",
    "o = mx_model(x)\n",
    "\n",
    "intermediate_input = mx_model.conv.input_quant(x)\n",
    "assert isinstance(intermediate_input, GroupwiseFloatQuantTensor)\n",
    "assert isinstance(mx_model.conv.quant_weight(), GroupwiseFloatQuantTensor)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If the input channel dimension is not divisible by the group size, padding is applied."
   ]
  },
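  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One way to see why padding need not change the result is the following pure-Python sketch.\n",
    "It assumes that the padding uses zeros and that the scale is derived from the largest magnitude in each group; both are assumptions of this example rather than confirmed details of Brevitas, and `pad_to_multiple` and `group_absmax` are hypothetical helpers.\n",
    "\n",
    "```python\n",
    "def pad_to_multiple(values, group_size):\n",
    "    # Zero-pad a list so that its length is a multiple of group_size\n",
    "    # (hypothetical helper; zero padding is an assumption of this sketch)\n",
    "    pad = (-len(values)) % group_size\n",
    "    return list(values) + [0.0] * pad\n",
    "\n",
    "\n",
    "def group_absmax(values, group_size):\n",
    "    # Largest magnitude in each group of group_size consecutive values\n",
    "    return [max(abs(v) for v in values[i:i + group_size])\n",
    "            for i in range(0, len(values), group_size)]\n",
    "\n",
    "\n",
    "channels = [0.5, -3.0, 1.0, 0.25, 6.0, 0.1, -0.2, 0.7]  # 8 input channels\n",
    "\n",
    "unpadded = group_absmax(channels, group_size=8)                      # one group of 8\n",
    "padded = group_absmax(pad_to_multiple(channels, 32), group_size=32)  # one group of 32\n",
    "\n",
    "print(len(pad_to_multiple(channels, 32)))  # 32\n",
    "print(unpadded, padded)                    # [6.0] [6.0]\n",
    "```\n",
    "\n",
    "Under these assumptions the padded zeros never raise the per-group maximum, so a group of 8 real values and a group of 32 values with 24 zeros share the same statistic."
   ]
  },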
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-05-09T14:50:31.635685Z",
     "iopub.status.busy": "2025-05-09T14:50:31.634449Z",
     "iopub.status.idle": "2025-05-09T14:50:31.759533Z",
     "shell.execute_reply": "2025-05-09T14:50:31.757310Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Non padding weights shape torch.Size([64, 1, 8, 3, 3])\n",
      "Padded weights shape torch.Size([64, 1, 32, 3, 3])\n"
     ]
    }
   ],
   "source": [
    "class MXFloat8WeightNoPadding(MXFloat8e4m3Weight, Fp8e4m3Mixin):\n",
    "    # group_size=8 divides the 8 input channels used below, so no padding is needed\n",
    "    # (overriding the group size means the output is no longer MX compliant)\n",
    "    group_size = 8\n",
    "\n",
    "class MXFloat8ActNoPadding(MXFloat8e4m3Act, Fp8e4m3Mixin):\n",
    "    # Same group size as the weights\n",
    "    group_size = 8\n",
    "\n",
    "\n",
    "class MXModelNoPadding(nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.conv = qnn.QuantConv2d(8, 64, 3, weight_quant=MXFloat8WeightNoPadding, input_quant=MXFloat8ActNoPadding)\n",
    "    \n",
    "    def forward(self, x):\n",
    "        return self.conv(x)\n",
    "\n",
    "# Same layer with the default MX group size of 32: the 8 input channels are padded to 32\n",
    "class MXModel(nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.conv = qnn.QuantConv2d(8, 64, 3, weight_quant=MXFloat8Weight, input_quant=MXFloat8Act)\n",
    "    \n",
    "    def forward(self, x):\n",
    "        return self.conv(x)\n",
    "\n",
    "mx_model_no_padding = MXModelNoPadding()\n",
    "mx_model = MXModel()\n",
    "# Make sure that the modules are the same\n",
    "mx_model_no_padding.load_state_dict(mx_model.state_dict())\n",
    "\n",
    "x = torch.randn(1, 8, 8, 8)\n",
    "mx_model.eval()\n",
    "mx_model_no_padding.eval()\n",
    "o_no_padding = mx_model_no_padding(x)\n",
    "o = mx_model(x)\n",
    "\n",
    "# The quantized weights of the padded model have a different shape from the unpadded ones\n",
    "print(f\"Non padding weights shape {mx_model_no_padding.conv.quant_weight().value_.shape}\")\n",
    "print(f\"Padded weights shape {mx_model.conv.quant_weight().value_.shape}\")\n",
    "\n",
    "# However, the results are still the same\n",
    "assert torch.allclose(o, o_no_padding)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-05-09T14:50:31.785658Z",
     "iopub.status.busy": "2025-05-09T14:50:31.784798Z",
     "iopub.status.idle": "2025-05-09T14:50:31.821099Z",
     "shell.execute_reply": "2025-05-09T14:50:31.819239Z"
    }
   },
   "outputs": [],
   "source": [
    "from brevitas.quant.experimental.mx_quant_ocp import MXFloat8e4m3WeightMSE\n",
    "from brevitas.quant_tensor import GroupwiseFloatQuantTensor\n",
    "\n",
    "\n",
    "class MXFloat8Weight(MXFloat8e4m3WeightMSE, Fp8e4m3Mixin):\n",
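    "    # Same as the MX weight quantizer above, but (as the class name suggests) the scale is selected with an MSE-based criterion\n",
    "    # The group dimension for the weights is automatically identified based on the layer type\n",
    "    # If a new layer type is used, it can be manually specified\n",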
    "    pass\n",
    "\n",
    "class MXFloat8Act(MXFloat8e4m3Act, Fp8e4m3Mixin):\n",
    "    # In layerwise quantization, groupdim is automatically determined\n",
    "    pass\n",
    "\n",
    "\n",
    "class MXModel(nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.conv = qnn.QuantConv2d(32, 64, 3, weight_quant=MXFloat8Weight, input_quant=MXFloat8Act)\n",
    "    \n",
    "    def forward(self, x):\n",
    "        return self.conv(x)\n",
    "\n",
    "mx_model = MXModel()\n",
    "x = torch.randn(1, 32, 8, 8)\n",
    "mx_model.eval()\n",
    "o = mx_model(x)\n",
    "\n",
    "intermediate_input = mx_model.conv.input_quant(x)\n",
    "assert isinstance(intermediate_input, GroupwiseFloatQuantTensor)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-05-09T14:50:31.832443Z",
     "iopub.status.busy": "2025-05-09T14:50:31.831002Z",
     "iopub.status.idle": "2025-05-09T14:50:31.863010Z",
     "shell.execute_reply": "2025-05-09T14:50:31.861188Z"
    }
   },
   "outputs": [],
   "source": [
    "from brevitas.quant_tensor import GroupwiseIntQuantTensor\n",
    "from brevitas.quant.experimental.mx_quant_ocp import MXInt8Weight\n",
    "from brevitas.quant.experimental.mx_quant_ocp import MXInt8Act\n",
    "import torch.nn as nn\n",
    "import brevitas.nn as qnn\n",
    "import torch\n",
    "\n",
    "class MXInt8Weight(MXInt8Weight):\n",
    "    # The group dimension for the weights is automatically identified based on the layer type\n",
    "    # If a new layer type is used, it can be manually specified\n",
    "    pass\n",
    "\n",
    "class MXInt8Act(MXInt8Act):\n",
    "    # In layerwise quantization, groupdim is automatically determined\n",
    "    pass\n",
    "\n",
    "class MXModel(nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.conv = qnn.QuantConv2d(32, 64, 3, weight_quant=MXInt8Weight, input_quant=MXInt8Act)\n",
    "    \n",
    "    def forward(self, x):\n",
    "        return self.conv(x)\n",
    "\n",
    "mx_model = MXModel()\n",
    "x = torch.randn(1, 32, 8, 8)\n",
    "mx_model.eval()\n",
    "o = mx_model(x)\n",
    "\n",
    "intermediate_input = mx_model.conv.input_quant(x)\n",
    "assert isinstance(intermediate_input, GroupwiseIntQuantTensor)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "brevitas_dev",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
